#!/usr/bin/env python
# -*-coding=utf-8-*-

import os
from tensorflow.contrib import learn
import tensorflow as tf
import numpy as np

class CnnModel(object):
    """Inference wrapper around a saved TensorFlow 1.x text-CNN checkpoint.

    On construction it does three things:

    * restores the most recent checkpoint in ``checkpoint_dir`` into a
      private graph and session;
    * resolves the tensors named "input_x", "dropout_keep_prob",
      "output/predictions" and "output/predprob";
    * restores the ``VocabularyProcessor`` saved at
      ``<checkpoint_dir>/../vocab``.

    The session stays open for the object's lifetime. Call ``close()``, or
    use the instance in a ``with`` statement, to release it.
    """

    def __init__(self, checkpoint_dir):
        """Restore the model graph, weights and vocabulary.

        Args:
            checkpoint_dir: Directory searched with
                ``tf.train.latest_checkpoint``. Its parent directory is
                expected to contain the saved vocabulary ("vocab").

        Raises:
            IOError: If no checkpoint is found in ``checkpoint_dir``.
        """
        self.checkpoint_dir = checkpoint_dir
        self.vocab_processor = None
        self.sess = None

        self.graph = tf.Graph()
        self.checkpoint_file = tf.train.latest_checkpoint(self.checkpoint_dir)
        # latest_checkpoint returns None when nothing is found. Without this
        # check the failure would surface later as a confusing "None.meta"
        # file-not-found error.
        if self.checkpoint_file is None:
            raise IOError(
                "No checkpoint found in {!r}".format(self.checkpoint_dir))
        print(self.checkpoint_file)

        try:
            self._restore_graph()
            self._load_vocab()
        except Exception:
            # Do not leak the TF session if restoring the graph or vocab fails.
            self.close()
            raise

    def _restore_graph(self):
        """Import the meta-graph, restore the weights and resolve tensors."""
        with self.graph.as_default():
            session_conf = tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=False)
            self.sess = tf.Session(config=session_conf)
            saver = tf.train.import_meta_graph(
                "{}.meta".format(self.checkpoint_file))
            saver.restore(self.sess, self.checkpoint_file)
            self.input_x = self._tensor("input_x")
            self.dropout_keep_prob = self._tensor("dropout_keep_prob")
            self.predictions = self._tensor("output/predictions")
            self.scores = self._tensor("output/predprob")

    def _tensor(self, op_name):
        """Return the first output tensor of graph operation ``op_name``."""
        return self.graph.get_operation_by_name(op_name).outputs[0]

    def _load_vocab(self):
        """Restore the VocabularyProcessor saved next to the checkpoint dir."""
        vocab_path = os.path.join(self.checkpoint_dir, "..", "vocab")
        self.vocab_processor = (
            learn.preprocessing.VocabularyProcessor.restore(vocab_path))

    def close(self):
        """Close the underlying TensorFlow session, if one was created."""
        sess = getattr(self, "sess", None)
        if sess is not None:
            sess.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def predict(self, text_list):
        """Classify a batch of texts in one session run.

        Args:
            text_list: Iterable of text strings, one document per element.
                A bare string would be iterated character by character.
                Presumably it should be tokenized the same way as the
                training data (TODO confirm).

        Returns:
            A ``(labels, probs)`` tuple of lists, one entry per row of the
            batch. ``labels`` holds the "output/predictions" values as plain
            Python numbers (ints for an integer predictions tensor).
            ``probs`` holds the row-wise maximum of "output/predprob",
            presumably the probability of the predicted class.
            An empty input yields ``([], [])``.
        """
        text_list = list(text_list)
        if not text_list:
            # np.array([]) has shape (0,), which presumably does not match
            # the 2-D input_x placeholder, so short-circuit the empty batch.
            return [], []

        input_text = np.array(list(self.vocab_processor.transform(text_list)))
        # dropout_keep_prob = 1.0 disables dropout for inference.
        scores, predictions = self.sess.run(
            [self.scores, self.predictions],
            {self.input_x: input_text, self.dropout_keep_prob: 1.0})
        # .tolist() yields plain Python numbers. The former
        # np.concatenate([[], predictions]) upcast class ids to float64.
        labels = np.asarray(predictions).tolist()
        probs = np.max(scores, axis=1).tolist()
        return labels, probs


if __name__ == '__main__':
    import sys

    # The first CLI argument, if given, overrides the checkpoint directory.
    # The default is the run this script was originally written against.
    checkpoint_dir = (sys.argv[1] if len(sys.argv) > 1
                      else "./runs/1510117513/checkpoints/")
    model = CnnModel(checkpoint_dir)
    try:
        # Sample inputs: space-separated tokens. The second sentence differs
        # from the first in its first token and in how one phrase is segmented.
        text_list = ["你 把 电话 留 一下", "我 把电话 留 一下"]
        labels, probs = model.predict(text_list)
        print(labels)
        print(probs)
    finally:
        # Release the TensorFlow session held by the model.
        model.sess.close()
